{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# RL Exercise 3 - Custom Environments and Reward Shaping\n",
    "\n",
    "**GOAL:** The goal of this exercise is to demonstrate how to adapt your own problem to use RLlib.\n",
    "\n",
    "To understand how to use **RLlib**, see the documentation at http://rllib.io.\n",
    "\n",
    "RLlib is not only easy to use in simulated benchmarks but also in the real-world. Here, we will cover two important concepts: how to create your own Markov Decision Process abstraction, and how to shape the reward of your environment so make your agent more effective. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "! pip install -U ray[rllib]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from __future__ import absolute_import\n",
    "from __future__ import division\n",
    "from __future__ import print_function\n",
    "\n",
    "import gym\n",
    "from gym import spaces\n",
    "import numpy as np\n",
    "import test_exercises\n",
    "\n",
    "import ray\n",
    "from ray.rllib.agents.ppo import PPOTrainer, DEFAULT_CONFIG\n",
    "\n",
    "ray.init(ignore_reinit_error=True, log_to_driver=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1. Different Spaces"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The first thing to do when formulating an RL problem is to specify the dimensions of your observation space and action space. Abstractions for these are provided in ``gym``. "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### **Exercise 1:** Match different actions to their corresponding space.\n",
    "\n",
    "The purpose of this exercise is to familiarize you with different Gym spaces. For example:\n",
    "\n",
    "    discrete = spaces.Discrete(10)\n",
    "    print(\"Random sample of this space: \", [discrete.sample() for i in range(4)])\n",
    "\n",
    "Use `help(spaces)` or `help([specific space])` (i.e., `help(spaces.Discrete)`) for more info."
   ]
  },
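  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a minimal sketch of how `sample` and `contains` behave, consider the snippet below. The particular spaces and values are illustrative assumptions and are deliberately different from the ones in the exercise:\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "from gym import spaces\n",
    "\n",
    "# Illustrative spaces, chosen only for this sketch.\n",
    "discrete = spaces.Discrete(4)  # integers 0..3\n",
    "box = spaces.Box(low=-1.0, high=1.0, shape=(2, 2), dtype=np.float32)\n",
    "multi = spaces.MultiDiscrete([3, 2])  # first entry in 0..2, second in 0..1\n",
    "\n",
    "# A sample drawn from a space is always contained in that space.\n",
    "for space in (discrete, box, multi):\n",
    "    assert space.contains(space.sample())\n",
    "\n",
    "# Membership depends on both the shape and the value range.\n",
    "in_box = box.contains(np.zeros((2, 2), dtype=np.float32))\n",
    "wrong_shape = box.contains(np.zeros(2, dtype=np.float32))\n",
    "in_multi = multi.contains(np.array([2, 1]))\n",
    "out_of_range = multi.contains(np.array([3, 1]))\n",
    "print(in_box, wrong_shape, in_multi, out_of_range)\n",
    "```\n",
    "\n",
    "If a hand-written value is rejected unexpectedly, comparing its `shape` and `dtype` against the space is usually a good first check."
   ]
  },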
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "action_space_map = {\n",
    "    \"discrete_10\": spaces.Discrete(10),\n",
    "    \"box_1\": spaces.Box(0, 1, shape=(1,)),\n",
    "    \"box_3x1\": spaces.Box(-2, 2, shape=(3, 1)),\n",
    "    \"multi_discrete\": spaces.MultiDiscrete([ 5, 2, 2, 4 ])\n",
    "}\n",
    "\n",
    "action_space_jumble = {\n",
    "    \"discrete_10\": 1,\n",
    "    \"CHANGE_ME\": np.array([0, 0, 0, 2]),\n",
    "    \"CHANGE_ME\": np.array([[-1.2657754], [-1.6528835], [ 0.5982418]]),\n",
    "    \"CHANGE_ME\": np.array([0.89089584]),\n",
    "}\n",
    "\n",
    "\n",
    "for space_id, state in action_space_jumble.items():\n",
    "    assert action_space_map[space_id].contains(state), (\n",
    "        \"Looks like {} to {} is matched incorrectly.\".format(space_id, state))\n",
    "    \n",
    "print(\"Success!\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## **Exercise 2**: Setting up a custom environment with rewards\n",
    "\n",
    "We'll setup an `n-Chain` environment, which presents moves along a linear chain of states, with two actions:\n",
    "\n",
    "     (0) forward, which moves along the chain but returns no reward\n",
    "     (1) backward, which returns to the beginning and has a small reward\n",
    "\n",
    "The end of the chain, however, presents a large reward, and by moving 'forward', at the end of the chain this large reward can be repeated.\n",
    "\n",
    "#### Step 1: Implement ``ChainEnv._setup_spaces``\n",
    "\n",
    "We'll use a `spaces.Discrete` action space and observation space. Implement `ChainEnv._setup_spaces` so that `self.action_space` and `self.obseration_space` are proper gym spaces.\n",
    "  \n",
    "1. Observation space is an integer in ``[0 to n-1]``.\n",
    "2. Action space is an integer in ``[0, 1]``.\n",
    "\n",
    "For example:\n",
    "\n",
    "```python\n",
    "    self.action_space = spaces.Discrete(2)\n",
    "    self.observation_space = ...\n",
    "```\n",
    "\n",
    "You should see a message indicating tests passing when done correctly!\n",
    "\n",
    "#### Step 2: Implement a reward function.\n",
    "\n",
    "When `env.step` is called, it returns a tuple of ``(state, reward, done, info)``. Right now, the reward is always 0. \n",
    "\n",
    "Implement it so that \n",
    "\n",
    "1. ``action == 1`` will return `self.small_reward`.\n",
    "2. ``action == 0`` will return 0 if `self.state < self.n - 1`.\n",
    "3. ``action == 0`` will return `self.large_reward` if `self.state == self.n - 1`.\n",
    "\n",
    "You should see a message indicating tests passing when done correctly. "
   ]
  },
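  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The `reset`/`step` contract described above can be sketched without any RL library. The `CountdownEnv` class below is a hypothetical toy, unrelated to the chain environment, and is written as a plain Python class rather than a `gym.Env` subclass so that the snippet stays self-contained:\n",
    "\n",
    "```python\n",
    "import random\n",
    "\n",
    "\n",
    "class CountdownEnv:\n",
    "    '''Hypothetical toy environment: the state counts down from `start` to 0.'''\n",
    "\n",
    "    def __init__(self, start=3):\n",
    "        self.start = start\n",
    "        self.state = start\n",
    "\n",
    "    def reset(self):\n",
    "        # Begin a new episode and return the initial observation.\n",
    "        self.state = self.start\n",
    "        return self.state\n",
    "\n",
    "    def step(self, action):\n",
    "        # Action 1 decrements the counter and earns 1; action 0 does nothing.\n",
    "        if action == 1:\n",
    "            self.state -= 1\n",
    "        reward = 1 if action == 1 else 0\n",
    "        done = self.state == 0\n",
    "        return self.state, reward, done, {}\n",
    "\n",
    "\n",
    "env = CountdownEnv(start=3)\n",
    "state = env.reset()\n",
    "done = False\n",
    "total_reward = 0\n",
    "while not done:\n",
    "    action = random.choice([0, 1])  # random policy, for illustration only\n",
    "    state, reward, done, info = env.step(action)\n",
    "    total_reward += reward\n",
    "print(state, total_reward)\n",
    "```\n",
    "\n",
    "The same loop shape (reset once, then step until `done`) is what the evaluation cells later in this notebook use."
   ]
  },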
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class ChainEnv(gym.Env):\n",
    "    \n",
    "    def __init__(self, env_config = None):\n",
    "        env_config = env_config or {}\n",
    "        self.n = env_config.get(\"n\", 20)\n",
    "        self.small_reward = env_config.get(\"small\", 2)  # payout for 'backwards' action\n",
    "        self.large_reward = env_config.get(\"large\", 10)  # payout at end of chain for 'forwards' action\n",
    "        self.state = 0  # Start at beginning of the chain\n",
    "        self._horizon = self.n\n",
    "        self._counter = 0  # For terminating the episode\n",
    "        self._setup_spaces()\n",
    "    \n",
    "    def _setup_spaces(self):\n",
    "        ##############\n",
    "        # TODO: Implement this so that it passes tests\n",
    "        self.action_space = None\n",
    "        self.observation_space = None\n",
    "        ##############\n",
    "\n",
    "    def step(self, action):\n",
    "        assert self.action_space.contains(action)\n",
    "        if action == 1:  # 'backwards': go back to the beginning, get small reward\n",
    "            ##############\n",
    "            # TODO 2: Implement this so that it passes tests\n",
    "            reward = -1\n",
    "            ##############\n",
    "            self.state = 0\n",
    "        elif self.state < self.n - 1:  # 'forwards': go up along the chain\n",
    "            ##############\n",
    "            # TODO 2: Implement this so that it passes tests\n",
    "            reward = -1\n",
    "            self.state += 1\n",
    "        else:  # 'forwards': stay at the end of the chain, collect large reward\n",
    "            ##############\n",
    "            # TODO 2: Implement this so that it passes tests\n",
    "            reward = -1\n",
    "            ##############\n",
    "        self._counter += 1\n",
    "        done = self._counter >= self._horizon\n",
    "        return self.state, reward, done, {}\n",
    "\n",
    "    def reset(self):\n",
    "        self.state = 0\n",
    "        self._counter = 0\n",
    "        return self.state\n",
    "    \n",
    "# Tests here:\n",
    "test_exercises.test_chain_env_spaces(ChainEnv)\n",
    "test_exercises.test_chain_env_reward(ChainEnv)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Let's now train a policy on the environment and evaluate this policy on our environment.\n",
    "\n",
    "You'll see that despite an extremely high reward, the policy has barely explored the state space."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "trainer_config = DEFAULT_CONFIG.copy()\n",
    "trainer_config['num_workers'] = 1\n",
    "trainer_config[\"train_batch_size\"] = 400\n",
    "trainer_config[\"sgd_minibatch_size\"] = 64\n",
    "trainer_config[\"num_sgd_iter\"] = 10"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "trainer = PPOTrainer(trainer_config, ChainEnv);\n",
    "for i in range(20):\n",
    "    print(\"Training iteration {}...\".format(i))\n",
    "    trainer.train()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "env = ChainEnv({})\n",
    "state = env.reset()\n",
    "\n",
    "done = False\n",
    "max_state = -1\n",
    "cumulative_reward = 0\n",
    "\n",
    "while not done:\n",
    "    action = trainer.compute_action(state)\n",
    "    state, reward, done, results = env.step(action)\n",
    "    max_state = max(max_state, state)\n",
    "    cumulative_reward += reward\n",
    "\n",
    "print(\"Cumulative reward you've received is: {}. Congratulations!\".format(cumulative_reward))\n",
    "print(\"Max state you've visited is: {}. This is out of {} states.\".format(max_state, env.n))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Exercise 3: Shaping the reward to encourage proper behavior.\n",
    "\n",
    "You'll see that despite an extremely high reward, the policy has barely explored the state space. This is often the situation - where the reward designed to encourage a particular solution is suboptimal, and the behavior created is unintended.\n",
    "\n",
    "#### Modify `ShapedChainEnv.step` to provide a reward that encourages the policy to traverse the chain (not just stick to 0). Do not change the behavior of the environment (the action -> state behavior should be the same).\n",
    "\n",
    "You can change the reward to be whatever you wish."
   ]
  },
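  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "One way to see why the trained policy tends to stay at state 0 is to compare the episode return of two fixed policies under the default settings (`n = 20`, small reward 2, large reward 10, horizon equal to `n`), assuming the reward rules from Exercise 2. This is a back-of-the-envelope sketch; the function names are illustrative and not part of the exercise code:\n",
    "\n",
    "```python\n",
    "def always_backward_return(horizon, small_reward):\n",
    "    # Every step takes action 1 and collects the small reward.\n",
    "    return horizon * small_reward\n",
    "\n",
    "\n",
    "def always_forward_return(n, horizon, large_reward):\n",
    "    # The first n - 1 steps walk to the end of the chain for zero reward;\n",
    "    # only the remaining steps collect the large reward.\n",
    "    steps_at_end = max(0, horizon - (n - 1))\n",
    "    return steps_at_end * large_reward\n",
    "\n",
    "\n",
    "n = 20\n",
    "backward = always_backward_return(horizon=n, small_reward=2)\n",
    "forward = always_forward_return(n=n, horizon=n, large_reward=10)\n",
    "print(backward, forward)\n",
    "```\n",
    "\n",
    "With these defaults, always going backward yields 40 per episode while always going forward yields 10, so within this horizon the unshaped reward actually favors a policy that never leaves state 0."
   ]
  },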
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class ShapedChainEnv(ChainEnv):\n",
    "    def step(self, action):\n",
    "        assert self.action_space.contains(action)\n",
    "        if action == 1:  # 'backwards': go back to the beginning\n",
    "            reward = -1\n",
    "            self.state = 0\n",
    "        elif self.state < self.n - 1:  # 'forwards': go up along the chain\n",
    "            reward = -1\n",
    "            self.state += 1\n",
    "        else:  # 'forwards': stay at the end of the chain\n",
    "            reward = -1\n",
    "        self._counter += 1\n",
    "        done = self._counter >= self._horizon\n",
    "        return self.state, reward, done, {}\n",
    "    \n",
    "test_exercises.test_chain_env_behavior(ShapedChainEnv)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Evaluate `ShapedChainEnv` by running the cell below.\n",
    "\n",
    "This trains PPO on the new env and counts the number of states seen."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "trainer = PPOTrainer(trainer_config, ShapedChainEnv);\n",
    "for i in range(20):\n",
    "    print(\"Training iteration {}...\".format(i))\n",
    "    trainer.train()\n",
    "\n",
    "env = ShapedChainEnv({})\n",
    "\n",
    "max_states = []\n",
    "\n",
    "for i in range(5):\n",
    "    state = env.reset()\n",
    "    done = False\n",
    "    max_state = -1\n",
    "    cumulative_reward = 0\n",
    "    while not done:\n",
    "        action = trainer.compute_action(state)\n",
    "        state, reward, done, results = env.step(action)\n",
    "        max_state = max(max_state, state)\n",
    "        cumulative_reward += reward\n",
    "    max_states += [max_state]\n",
    "\n",
    "print(\"Cumulative reward you've received is: {}!\".format(cumulative_reward))\n",
    "print(\"Max state you've visited is: {}. This is out of {} states.\".format(np.mean(max_states), env.n))\n",
    "assert (env.n - np.mean(max_states)) / env.n < 0.2, \"This policy did not traverse many states.\""
   ]
  },
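  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The final assertion above can be read as a coverage check: it passes when the mean of the per-episode maximum states falls short of the chain length by less than 20%. A minimal sketch of that arithmetic, using a hypothetical helper name and made-up episode results:\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "def coverage_gap(max_states, n):\n",
    "    # Fraction of the chain left unvisited, averaged over episodes.\n",
    "    return (n - np.mean(max_states)) / n\n",
    "\n",
    "\n",
    "# Made-up results for two hypothetical policies on a 20-state chain.\n",
    "explorer = coverage_gap([19, 19, 18, 19, 19], n=20)\n",
    "homebody = coverage_gap([1, 0, 2, 1, 1], n=20)\n",
    "print(explorer < 0.2, homebody < 0.2)\n",
    "```\n",
    "\n",
    "For `n = 20`, the threshold corresponds to a mean maximum state above 16."
   ]
  },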
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
